{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Federated Word Vectors</h1>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example demonstrates how word vector model PyTorch could be trained using federated learning with PySyft. We distribute the text data to two workers Bob and Alice to whom the model is sent and trained. Upon training the model the trained model is sent back to the owner of the model and used to make predictions or the embedding layer which consist of learnt word vectors could be used. Federated learning applied to word vectors could be a great way to analyze textual data without knowing the specifics of the text and risk invading privacy. In a real-time application , say understanding internal e-mail culture of a organization. In this example we learn a word embedding by trying to predict the next word given context of N words.\n",
    "\n",
    "Hrishikesh Kamath - GitHub: @<a href=\"http://github.com/kamathhrishi\">kamathhrishi</a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#Import modules required for PyTorch Neural Networks\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from torch.utils.data import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "# Shakespeare Sonnet 2 as text to be learned \n",
    "\n",
    "dataset = \"\"\"When forty winters shall besiege thy brow,\n",
    "And dig deep trenches in thy beauty's field,\n",
    "Thy youth's proud livery so gazed on now,\n",
    "Will be a totter'd weed of small worth held:\n",
    "Then being asked, where all thy beauty lies,\n",
    "Where all the treasure of thy lusty days;\n",
    "To say, within thine own deep sunken eyes,\n",
    "Were an all-eating shame, and thriftless praise.\n",
    "How much more praise deserv'd thy beauty's use,\n",
    "If thou couldst answer 'This fair child of mine\n",
    "Shall sum my count, and make my old excuse,'\n",
    "Proving his beauty by succession thine!\n",
    "This were to be new made when thou art old,\n",
    "And see thy blood warm when thou feel'st it cold.\"\"\".split()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size = 1\n",
    "        self.test_batch_size = 1000\n",
    "        self.epochs = 10\n",
    "        self.lr = 0.01\n",
    "        self.momentum = 0.5 #<-We currenly do not support momentum\n",
    "        self.no_cuda = False\n",
    "        self.seed = 1\n",
    "        self.log_interval = 10\n",
    "        self.save_model = False\n",
    "        self.context_size=3\n",
    "        self.embedding_dim=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "args=Arguments()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x117c60790>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Define seed to maintain consistency \n",
    "torch.manual_seed(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#Import PySyft library required for federated learning\n",
    "import syft as sy  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#Define Syft workers Bob and Alice for federated learning\n",
    "\n",
    "hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch ie add extra functionalities to support Federated Learning\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")  # <-- NEW: define remote worker bob\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")  # <-- NEW: and alice\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#vocabulary of from the corpus\n",
    "vocab = set(dataset)\n",
    "word_to_ix = {word: i for i, word in enumerate(vocab)}\n",
    "ix_to_word={word_to_ix[word]:word for word in word_to_ix}"
   ]
  },
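  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the index mapping above: a Python `set` of strings has no guaranteed iteration order across interpreter runs, so the word indices can change from run to run. The sketch below is not part of the original notebook. It shows one way to make the mapping stable, assuming that sorting the vocabulary is acceptable; the toy corpus is made up for illustration.\n",
    "\n",
    "```python\n",
    "# Toy corpus (hypothetical, for illustration only)\n",
    "corpus = 'to be or not to be'.split()\n",
    "\n",
    "# Sorting the set fixes the order, so every run assigns the same indices\n",
    "vocab = sorted(set(corpus))\n",
    "word_to_ix = {word: i for i, word in enumerate(vocab)}\n",
    "ix_to_word = {i: word for word, i in word_to_ix.items()}\n",
    "\n",
    "print(word_to_ix)\n",
    "print(ix_to_word)\n",
    "```\n",
    "\n",
    "With a stable mapping, a saved embedding matrix can be reloaded later and its rows still refer to the same words."
   ]
  },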
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Torch Dataset</h2>\n",
    "Convert text dataset into a torch dataset instance which we will need to create a federated dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "class TextDataset(Dataset):\n",
    "\n",
    "    def __init__(self,text,transform=None):\n",
    "        \n",
    "        \"\"\"arguments:\n",
    "        \n",
    "             text (List of Strings): Text corpus \n",
    "             transform: List of transforms to be performed on the input data\n",
    "             \n",
    "        \"\"\"\n",
    "\n",
    "        self.text = text\n",
    "        self.data=[]\n",
    "        self.targets=[]\n",
    "        self.transform = transform\n",
    "        \n",
    "        #Create Trigrams \n",
    "        self.create_context()\n",
    "\n",
    "    def __len__(self):\n",
    "        \n",
    "        return len(self.data)\n",
    "    \n",
    "    def create_context(self):\n",
    "        \n",
    "        '''Function used to seperate target and context words and convert them to torch tensors'''\n",
    "        \n",
    "        context=[]\n",
    "        \n",
    "        for i in range(len(self.text)-args.context_size):\n",
    "            \n",
    "            vec=[]\n",
    "            \n",
    "            for j in range(0,args.context_size):\n",
    "                \n",
    "                vec.append(self.text[i+j])\n",
    "                \n",
    "            context.append([vec,self.text[i+args.context_size]])\n",
    "                \n",
    "        \n",
    "        for words,target in context:\n",
    "            \n",
    "            tensor=torch.tensor([word_to_ix[w] for w in words],dtype=torch.long)\n",
    "            self.data.append(tensor)\n",
    "            self.targets.append(torch.tensor([word_to_ix[target]], dtype=torch.long))\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \n",
    "        sample=self.data[idx]\n",
    "        target=self.targets[idx]\n",
    "                \n",
    "        if self.transform:\n",
    "            sample = self.transform(sample)\n",
    "\n",
    "        return sample,target"
   ]
  },
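  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The context/target construction in `create_context` can be illustrated on plain Python lists, without tensors. This is a minimal sketch rather than part of the original notebook: `make_context_pairs` is a hypothetical helper name, and the example uses only the first line of the sonnet.\n",
    "\n",
    "```python\n",
    "def make_context_pairs(tokens, context_size):\n",
    "    # Slide a window over the tokens: each run of `context_size` words\n",
    "    # is paired with the single word that follows it.\n",
    "    pairs = []\n",
    "    for i in range(len(tokens) - context_size):\n",
    "        context = tokens[i:i + context_size]\n",
    "        target = tokens[i + context_size]\n",
    "        pairs.append((context, target))\n",
    "    return pairs\n",
    "\n",
    "\n",
    "tokens = 'When forty winters shall besiege thy brow,'.split()\n",
    "for context, target in make_context_pairs(tokens, context_size=3):\n",
    "    print(context, '->', target)\n",
    "```\n",
    "\n",
    "A text of `len(tokens)` words yields `len(tokens) - context_size` pairs, which is also the length reported by `TextDataset.__len__`."
   ]
  },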
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use federated data loader to distribute dataset to workers. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scanning and sending data to bob, alice...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader \n",
    "                         TextDataset(dataset)\n",
    "                         .federate((bob, alice)),batch_size=args.batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Neural Network Model</h2>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define Neural Network in PyTorch. The network is trained to predict the next word based on given context. Based on the trained model the required embedding is learnt. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class NGramLanguageModeler(nn.Module):\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_dim, context_size):\n",
    "        \n",
    "        super(NGramLanguageModeler, self).__init__()\n",
    "        self.embeddings = nn.Embedding(vocab_size,embedding_dim)\n",
    "        self.linear1 = nn.Linear(context_size * embedding_dim, 128)\n",
    "        self.linear2 = nn.Linear(128, vocab_size)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        \n",
    "        embeds = self.embeddings(inputs).view((1, -1))\n",
    "        out = F.relu(self.linear1(embeds))\n",
    "        out = self.linear2(out)\n",
    "        log_probs = F.log_softmax(out,dim=1)\n",
    "        \n",
    "        return log_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "loss_function = nn.NLLLoss()\n",
    "model = NGramLanguageModeler(len(vocab),args.embedding_dim,args.context_size)\n",
    "optimizer = optim.SGD(model.parameters(), lr=args.lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Train Model</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train():\n",
    "    \n",
    "    model.train()\n",
    "    iteration=0\n",
    "    for context, target in federated_train_loader:\n",
    "        \n",
    "        model.send(context.location)\n",
    "        # Step 1. Prepare the inputs to be passed to the model (i.e, turn the words\n",
    "        # into integer indices and wrap them in tensors)\n",
    "        context_idxs = context\n",
    "    \n",
    "        # Step 2. Recall that torch *accumulates* gradients. Before passing in a\n",
    "        # new instance, you need to zero out the gradients from the old\n",
    "        # instance\n",
    "        model.zero_grad()\n",
    "\n",
    "        # Step 3. Run the forward pass, getting log probabilities over next\n",
    "        # words\n",
    "        log_probs = model(context_idxs)\n",
    "\n",
    "        # Step 4. Compute your loss function. (Again, Torch wants the target\n",
    "        # word wrapped in a tensor)\n",
    "        loss = loss_function(log_probs,target[0])\n",
    "\n",
    "        # Step 5. Do the backward pass and update the gradient\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.get()\n",
    "        \n",
    "        # Get the Python number from a 1-element Tensor by calling tensor.item()\n",
    "        # The loss decreased every iteration over the training data!\n",
    "        iteration+=1\n",
    "        if(iteration%100==0):\n",
    "            \n",
    "            print(loss.get().item())\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/hrishikesh/anaconda3/envs/syft_1/lib/python3.6/site-packages/syft/frameworks/torch/tensors/interpreters/native.py:215: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  response = eval(cmd)(*args, **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.5544867515563965\n",
      "EPOCH:  1\n",
      "4.0127153396606445\n",
      "EPOCH:  2\n",
      "3.4864673614501953\n",
      "EPOCH:  3\n",
      "2.9545602798461914\n",
      "EPOCH:  4\n",
      "2.4121615886688232\n",
      "EPOCH:  5\n",
      "1.876990556716919\n",
      "EPOCH:  6\n",
      "1.40726637840271\n",
      "EPOCH:  7\n",
      "1.042470932006836\n",
      "EPOCH:  8\n",
      "0.7883358001708984\n",
      "EPOCH:  9\n",
      "0.6222090721130371\n",
      "EPOCH:  10\n"
     ]
    }
   ],
   "source": [
    "\n",
    "for epoch in range(0,args.epochs):\n",
    "    train()\n",
    "    print(\"EPOCH: \",epoch+1)        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "if (args.save_model):\n",
    "    torch.save(model.state_dict(), \"word_vector.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Visualize Results</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def SimilarPairs(model,vocab,inverse_vocab):\n",
    "    \n",
    "   #Function to compute the most similar pairs\n",
    "   \n",
    "   matrix=[]\n",
    "\n",
    "   for ref_index in range(0,len(vocab)):\n",
    "    \n",
    "      Max=-10.0\n",
    "      Index=0\n",
    "      \n",
    "      ref=model.embeddings(torch.LongTensor([ref_index]))\n",
    "      for i in range(0,len(vocab)):\n",
    "   \n",
    "           cos = nn.CosineSimilarity(dim=1, eps=1e-6)\n",
    "           output = cos(ref,model.embeddings(torch.LongTensor([i])))\n",
    "            \n",
    "           if(output.item()>Max and i!=ref_index):\n",
    "             \n",
    "             Max=output.item()\n",
    "             Index=i\n",
    "                \n",
    "      matrix.append([ix_to_word[ref_index],ix_to_word[Index],Max])\n",
    "    \n",
    "    \n",
    "   return(matrix)"
   ]
  },
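  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The nearest-neighbour search in `SimilarPairs` rests on cosine similarity. The sketch below works through the same idea in plain Python on made-up 2-dimensional vectors. It is an illustration under stated assumptions, not the notebook's method: `cosine_similarity`, `most_similar` and `toy_vectors` are hypothetical names, and the `eps` guard only loosely mirrors the one passed to `nn.CosineSimilarity`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def cosine_similarity(u, v, eps=1e-6):\n",
    "    # Dot product divided by the product of the vector lengths;\n",
    "    # eps guards against division by zero for all-zero vectors.\n",
    "    dot = sum(a * b for a, b in zip(u, v))\n",
    "    norm_u = math.sqrt(sum(a * a for a in u))\n",
    "    norm_v = math.sqrt(sum(b * b for b in v))\n",
    "    return dot / max(norm_u * norm_v, eps)\n",
    "\n",
    "\n",
    "def most_similar(word, vectors):\n",
    "    # Return the other word whose vector has the highest cosine similarity\n",
    "    best_word, best_score = None, -math.inf\n",
    "    for other, vec in vectors.items():\n",
    "        if other == word:\n",
    "            continue\n",
    "        score = cosine_similarity(vectors[word], vec)\n",
    "        if score > best_score:\n",
    "            best_word, best_score = other, score\n",
    "    return best_word, best_score\n",
    "\n",
    "\n",
    "# Made-up vectors, chosen so that 'cat' and 'kitten' point the same way\n",
    "toy_vectors = {'cat': [1.0, 0.0], 'kitten': [0.9, 0.1], 'car': [0.0, 1.0]}\n",
    "print(most_similar('cat', toy_vectors))\n",
    "```\n",
    "\n",
    "Cosine similarity depends only on the angle between two vectors, not on their lengths, which is why it is a common choice for comparing embeddings."
   ]
  },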
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "similar_Pairs=SimilarPairs(model,word_to_ix,ix_to_word)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The word vectors learnt don't exactly capture meanings of actual words since it was trained on a smaller corpora. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['use,', 'And', 0.7608605623245239],\n",
       " ['more', 'were', 0.9239389896392822],\n",
       " ['count,', 'small', 0.597095251083374],\n",
       " ['say,', 'within', 0.6586640477180481],\n",
       " ['forty', 'mine', 0.6533328890800476],\n",
       " ['dig', 'see', 0.7790502309799194],\n",
       " ['treasure', 'lies,', 0.5748651027679443],\n",
       " [\"youth's\", 'child', 0.8469692468643188],\n",
       " ['mine', 'gazed', 0.7669305801391602],\n",
       " ['small', 'thy', 0.7145660519599915],\n",
       " ['If', \"deserv'd\", 0.6726179718971252],\n",
       " ['his', 'child', 0.6212238669395447],\n",
       " ['days;', 'praise.', 0.8395010232925415],\n",
       " ['worth', 'now,', 0.7184102535247803],\n",
       " ['sunken', 'shame,', 0.7679064869880676],\n",
       " ['held:', \"totter'd\", 0.6357218027114868],\n",
       " ['be', 'sum', 0.780488133430481],\n",
       " ['To', 'thine!', 0.7006815075874329],\n",
       " ['weed', 'my', 0.802433967590332]]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Similar pairs of first 20 words\n",
    "similar_Pairs[1:20]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Well Done!\n",
    "\n",
    "And voilà! We now are training a real world Learning model using Federated Learning! \n",
    "\n",
    "## Shortcomings of this Example\n",
    "\n",
    "Of course, there are dozen of improvements we could think of. We would like the computation to operate in parallel on the workers, to update the central model every `n` batches only, to reduce the number of messages we use to communicate between workers, etc.\n",
    "\n",
    "On the security side it still has some major shortcomings. Most notably, when we call `model.get()` and receive the updated model from Bob or Alice, we can actually learn a lot about Bob and Alice's training data by looking at their gradients. We could **average the gradient across multiple individuals before uploading it to the central server**, like we did in Part 4.\n",
    "\n",
    "The above embeddings are not useful for practical purposes as they are trained on a very small corpus. Increasing corpus size could lead to more useful embeddings."
   ]
  },
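  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic behind averaging updates across workers can be sketched in plain Python. This is only an illustration, not PySyft code: `average_updates` is a hypothetical helper, each update is assumed to be a flat list of floats of equal length, and the numbers are made up. Note that plain averaging by itself does not hide individual updates from whoever computes the average; that requires an additional mechanism such as a trusted aggregator or secure aggregation.\n",
    "\n",
    "```python\n",
    "def average_updates(updates):\n",
    "    # updates: one list of floats per worker, all of the same length.\n",
    "    # zip(*updates) groups the values that belong to the same parameter.\n",
    "    n_workers = len(updates)\n",
    "    return [sum(values) / n_workers for values in zip(*updates)]\n",
    "\n",
    "\n",
    "# Made-up gradient values for two workers\n",
    "bob_update = [0.25, -0.5, 1.0]\n",
    "alice_update = [0.75, 0.5, -1.0]\n",
    "\n",
    "print(average_updates([bob_update, alice_update]))\n",
    "```\n",
    "\n",
    "Only the averaged list would be sent on to the central model, so a single worker's contribution is blended with the others before it is seen."
   ]
  },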
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Pick our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Checkout the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! \n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to PySyft GitHub Issues page and search for issues marked `Good First Issue`.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
